/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Unit tests for thread-level GEMM
*/

#include <fstream>

#include "../../common/cutlass_unit_test.h"

#include "cutlass/aligned_buffer.h"
#include "cutlass/half.h"

#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"

#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace test {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename TileIterator>
__global__ void kernel_store_iterator(typename TileIterator::Params params,
                                      typename TileIterator::TensorRef ref,
                                      cutlass::MatrixCoord extent) {
    TileIterator iterator(params, ref.data(), extent, threadIdx.x, {0, 0});

    typename TileIterator::Fragment fragment;

    CUTLASS_PRAGMA_NO_UNROLL
    for (int iter = 0; iter < TileIterator::ThreadMap::Count::kTile; ++iter) {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < TileIterator::Fragment::kElements; ++i) {
            typename TileIterator::Element tidx(iter + 1);
            fragment[i] = tidx;
        }

        iterator.store(fragment);

        ++iterator;
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace epilogue
}  // namespace test

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, typename Layout>
static bool verify_footprint(cutlass::TensorView<T, Layout> view,
                             cutlass::MatrixCoord extent) {
    for (int r = 0; r < view.extent().row(); ++r) {
        for (int c = 0; c < view.extent().column(); ++c) {
            cutlass::MatrixCoord coord{r, c};
            bool within = coord < extent;
            if (within) {
                if (view.at(coord) == T(0)) {
                    return false;
                }
            } else {
                if (view.at(coord) != T(0)) {
                    return false;
                }
            }
        }
    }

    return true;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

TEST(PredicatedTileIterator, tensor_op_64x64x32_64x64x8) {
    using Layout = cutlass::layout::RowMajor;
    using Element = int;

    static int const kElementsPerAccess =
            128 / cutlass::sizeof_bits<Element>::value;
    static int const kThreads = 32;

    //
    // The following tests were used to develop the OutputTileOptimalThreadMap
    // metaprogram. The definitions in the disabled blocks of code in this and
    // the following tests are hand-written quantities. They are expected to
    // match what is defined in the ThreadMap.
    //

#if 1
    using ThreadMap =
            cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
                    cutlass::epilogue::threadblock::OutputTileShape<64, 8, 1, 1,
                                                                    1>,
                    cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1,
                                                                    8>,
                    kThreads, kElementsPerAccess,
                    cutlass::sizeof_bits<Element>::value>;

#else
    using InternalThreadMap =
            cutlass::transform::PitchLinearStripminedThreadMap<
                    cutlass::layout::PitchLinearShape<64, 64>, kThreads,
                    kElementsPerAccess>;

    using Shape =
            cutlass::epilogue::threadblock::OutputTileShape<64,  // column
                                                            8,   // row
                                                            1,   // group
                                                            1,   // cluster
                                                            1    // iterations
                                                            >;

    using Iterations =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            4,  // row
                                                            1,  // group
                                                            1,  // cluster
                                                            1   // iterations
                                                            >;

    using Delta =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            2,  // row
                                                            1,  // group
                                                            1,  // cluster
                                                            1   // iterations
                                                            >;

    using Count =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            8,  // row
                                                            1,  // group
                                                            1,  // cluster
                                                            8   // iterations
                                                            >;

    using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap<
            InternalThreadMap, Shape, Iterations, Delta, Count>;
#endif

    using PredicatedTileIterator =
            cutlass::epilogue::threadblock::PredicatedTileIterator<ThreadMap,
                                                                   Element>;

    //
    // Initialize workspace
    //
    cutlass::MatrixCoord tensor_extent{64, 64};
    cutlass::MatrixCoord output_extent{62, 56};

    //
    // Configure parameters
    //

    cutlass::HostTensor<Element, Layout> host_tensor(tensor_extent);

    typename PredicatedTileIterator::Params iterator_params(
            host_tensor.layout());

    host_tensor.sync_device();

    //
    // Launch kernel
    //

    dim3 grid(1, 1);
    dim3 block(kThreads, 1);

    test::epilogue::threadblock::kernel_store_iterator<PredicatedTileIterator>
            <<<grid, block>>>(iterator_params, host_tensor.device_ref(),
                              output_extent);

    cudaError_t result = cudaDeviceSynchronize();
    ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result);

    //
    // Verify results
    //

    host_tensor.sync_host();

    bool passed = verify_footprint(host_tensor.host_view(), output_extent);
    EXPECT_TRUE(passed);

    if (!passed) {
        std::ofstream output("tensor_op_64x64x32_64x64x8.csv");
        output << host_tensor.host_view();
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

TEST(PredicatedTileIterator, tensor_op_128x64x32_64x64x8) {
    using Layout = cutlass::layout::RowMajor;
    using Element = int;

    static int const kElementsPerAccess =
            128 / cutlass::sizeof_bits<Element>::value;
    static int const kThreads = 64;

#if 1

    using ThreadMap =
            cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
                    cutlass::epilogue::threadblock::OutputTileShape<128, 8, 2,
                                                                    1, 1>,
                    cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1,
                                                                    8>,
                    kThreads, kElementsPerAccess,
                    cutlass::sizeof_bits<Element>::value>;

#else
    using InternalThreadMap =
            cutlass::transform::PitchLinearStripminedThreadMap<
                    cutlass::layout::PitchLinearShape<64, 128>, kThreads,
                    kElementsPerAccess>;

    using Shape = cutlass::epilogue::threadblock::OutputTileShape<64,  // column
                                                                  8,   // row
                                                                  2,   // group
                                                                  1,  // cluster
                                                                  8   // tile
                                                                  >;

    using Iterations =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            2,  // row
                                                            2,  // group
                                                            1,  // cluster
                                                            1   // iterations
                                                            >;

    using Delta = cutlass::epilogue::threadblock::OutputTileShape<1,   // column
                                                                  4,   // row
                                                                  64,  // group
                                                                  1,  // cluster
                                                                  1   // tile
                                                                  >;

    using Count =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            8,  // row
                                                            1,  // group
                                                            1,  // cluster
                                                            8   // iterations
                                                            >;

    using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap<
            InternalThreadMap, Shape, Iterations, Delta, Count>;
#endif

    using PredicatedTileIterator =
            cutlass::epilogue::threadblock::PredicatedTileIterator<ThreadMap,
                                                                   Element>;

    //
    // Initialize workspace
    //
    cutlass::MatrixCoord tensor_extent{128, 64};
    cutlass::MatrixCoord output_extent{125, 56};

    //
    // Configure parameters
    //

    cutlass::HostTensor<Element, Layout> host_tensor(tensor_extent);

    typename PredicatedTileIterator::Params iterator_params(
            host_tensor.layout());

    host_tensor.sync_device();

    //
    // Launch kernel
    //

    dim3 grid(1, 1);
    dim3 block(kThreads, 1);

    test::epilogue::threadblock::kernel_store_iterator<PredicatedTileIterator>
            <<<grid, block>>>(iterator_params, host_tensor.device_ref(),
                              output_extent);

    cudaError_t result = cudaDeviceSynchronize();
    ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result);

    //
    // Verify results
    //

    host_tensor.sync_host();

    bool passed = verify_footprint(host_tensor.host_view(), output_extent);
    EXPECT_TRUE(passed);

    if (!passed) {
        std::ofstream output("tensor_op_128x64x32_64x64x8.csv");
        output << host_tensor.host_view();
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////////////////////////

TEST(PredicatedTileIterator, tensor_op_128x256x32_64x64x8) {
    using Layout = cutlass::layout::RowMajor;
    using Element = int;

    static int const kElementsPerAccess =
            128 / cutlass::sizeof_bits<Element>::value;
    static int const kThreads = 256;

#if 1

    using ThreadMap =
            cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
                    cutlass::epilogue::threadblock::OutputTileShape<256, 8, 2,
                                                                    1, 1>,
                    cutlass::epilogue::threadblock::OutputTileShape<1, 8, 2, 1,
                                                                    8>,
                    kThreads, kElementsPerAccess,
                    cutlass::sizeof_bits<Element>::value>;

#else
    using InternalThreadMap =
            cutlass::transform::PitchLinearStripminedThreadMap<
                    cutlass::layout::PitchLinearShape<256, 128>, kThreads,
                    kElementsPerAccess>;

    using Shape =
            cutlass::epilogue::threadblock::OutputTileShape<256,  // column
                                                            8,    // row
                                                            2,    // group
                                                            1,    // cluster
                                                            8     // tile
                                                            >;

    using Iterations =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            2,  // row
                                                            2,  // group
                                                            1,  // cluster
                                                            1   // iterations
                                                            >;

    using Delta = cutlass::epilogue::threadblock::OutputTileShape<1,   // column
                                                                  4,   // row
                                                                  64,  // group
                                                                  1,  // cluster
                                                                  1   // tile
                                                                  >;

    using Count =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            8,  // row
                                                            1,  // group
                                                            1,  // cluster
                                                            8   // iterations
                                                            >;

    using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap<
            InternalThreadMap, Shape, Iterations, Delta, Count>;
#endif

    using PredicatedTileIterator =
            cutlass::epilogue::threadblock::PredicatedTileIterator<ThreadMap,
                                                                   Element>;

    //
    // Initialize workspace
    //
    cutlass::MatrixCoord tensor_extent{128, 256};
    cutlass::MatrixCoord output_extent{123, 252};

    //
    // Configure parameters
    //

    cutlass::HostTensor<Element, Layout> host_tensor(tensor_extent);

    typename PredicatedTileIterator::Params iterator_params(
            host_tensor.layout());

    host_tensor.sync_device();

    //
    // Launch kernel
    //

    dim3 grid(1, 1);
    dim3 block(kThreads, 1);

    test::epilogue::threadblock::kernel_store_iterator<PredicatedTileIterator>
            <<<grid, block>>>(iterator_params, host_tensor.device_ref(),
                              output_extent);

    cudaError_t result = cudaDeviceSynchronize();
    ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result);

    //
    // Verify results
    //

    host_tensor.sync_host();

    bool passed = verify_footprint(host_tensor.host_view(), output_extent);
    EXPECT_TRUE(passed);

    if (!passed) {
        std::ofstream output("tensor_op_128x256x32_64x64x8.csv");
        output << host_tensor.host_view();
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

TEST(PredicatedTileIterator, volta_tensor_op_64x64x32_64x64x4) {
    using Layout = cutlass::layout::RowMajor;
    using Element = int;

    static int const kElementsPerAccess =
            128 / cutlass::sizeof_bits<Element>::value;
    static int const kThreads = 32;

#if 1

    using ThreadMap =
            cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
                    cutlass::epilogue::threadblock::OutputTileShape<64, 2, 4, 1,
                                                                    1>,
                    cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1,
                                                                    8>,
                    kThreads, kElementsPerAccess,
                    cutlass::sizeof_bits<Element>::value>;

#else
    using InternalThreadMap =
            cutlass::transform::PitchLinearStripminedThreadMap<
                    cutlass::layout::PitchLinearShape<64, 8>, kThreads,
                    kElementsPerAccess>;

    using Shape =
            cutlass::epilogue::threadblock::OutputTileShape<64,  // column
                                                            2,   // row
                                                            4,   // group
                                                            1,   // cluster
                                                            8    // iterations
                                                            >;

    using Iterations =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            1,  // row
                                                            4,  // group
                                                            1,  // cluster
                                                            1   // iterations
                                                            >;

    using Delta =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            1,  // row
                                                            8,  // group
                                                            1,  // cluster
                                                            1   // iterations
                                                            >;

    using Count =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            4,  // row
                                                            2,  // group
                                                            1,  // cluster
                                                            8   // iterations
                                                            >;

    using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap<
            InternalThreadMap, Shape, Iterations, Delta, Count>;
#endif

    using PredicatedTileIterator =
            cutlass::epilogue::threadblock::PredicatedTileIterator<ThreadMap,
                                                                   Element>;

    //
    // Initialize workspace
    //
    cutlass::MatrixCoord tensor_extent{64, 64};
    cutlass::MatrixCoord output_extent{62, 56};

    //
    // Configure parameters
    //

    cutlass::HostTensor<Element, Layout> host_tensor(tensor_extent);

    typename PredicatedTileIterator::Params iterator_params(
            host_tensor.layout());

    host_tensor.sync_device();

    //
    // Launch kernel
    //

    dim3 grid(1, 1);
    dim3 block(kThreads, 1);

    test::epilogue::threadblock::kernel_store_iterator<PredicatedTileIterator>
            <<<grid, block>>>(iterator_params, host_tensor.device_ref(),
                              output_extent);

    cudaError_t result = cudaDeviceSynchronize();
    ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result);

    //
    // Verify results
    //
    host_tensor.sync_host();

    bool passed = verify_footprint(host_tensor.host_view(), output_extent);
    EXPECT_TRUE(passed);

    if (!passed) {
        std::ofstream output("volta_tensor_op_64x64x32_64x64x4.csv");
        output << host_tensor.host_view();
    }
}

///////////////////////////////////////////////////////////////////////////////////////////////////

TEST(PredicatedTileIterator, volta_tensor_op_64x128x32_32x64x4) {
    using Layout = cutlass::layout::RowMajor;
    using Element = int;

    static int const kElementsPerAccess =
            128 / cutlass::sizeof_bits<Element>::value;
    static int const kThreads = 128;

#if 1

    using ThreadMap =
            cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
                    cutlass::epilogue::threadblock::OutputTileShape<128, 2, 4,
                                                                    1, 1>,
                    cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1,
                                                                    8>,
                    kThreads, kElementsPerAccess,
                    cutlass::sizeof_bits<Element>::value>;

#else
    using InternalThreadMap =
            cutlass::transform::PitchLinearStripminedThreadMap<
                    cutlass::layout::PitchLinearShape<128, 8>, kThreads,
                    kElementsPerAccess>;

    using Shape =
            cutlass::epilogue::threadblock::OutputTileShape<128,  // column
                                                            2,    // row
                                                            2,    // group
                                                            2,    // cluster
                                                            8     // iterations
                                                            >;

    using Iterations =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            1,  // row
                                                            1,  // group
                                                            2,  // cluster
                                                            1   // iterations
                                                            >;

    using Delta =
            cutlass::epilogue::threadblock::OutputTileShape<1,   // column
                                                            1,   // row
                                                            8,   // group
                                                            32,  // cluster
                                                            1    // iterations
                                                            >;

    using Count =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            4,  // row
                                                            4,  // group
                                                            1,  // cluster
                                                            8   // iterations
                                                            >;

    using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap<
            InternalThreadMap, Shape, Iterations, Delta, Count>;
#endif

    using PredicatedTileIterator =
            cutlass::epilogue::threadblock::PredicatedTileIterator<ThreadMap,
                                                                   Element>;

    //
    // Initialize workspace
    //
    cutlass::MatrixCoord tensor_extent{64, 128};
    cutlass::MatrixCoord output_extent{57, 124};

    //
    // Configure parameters
    //

    cutlass::HostTensor<Element, Layout> host_tensor(tensor_extent);

    typename PredicatedTileIterator::Params iterator_params(
            host_tensor.layout());

    host_tensor.sync_device();

    //
    // Launch kernel
    //

    dim3 grid(1, 1);
    dim3 block(kThreads, 1);

    test::epilogue::threadblock::kernel_store_iterator<PredicatedTileIterator>
            <<<grid, block>>>(iterator_params, host_tensor.device_ref(),
                              output_extent);

    cudaError_t result = cudaDeviceSynchronize();
    ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result);

    //
    // Verify results
    //

    host_tensor.sync_host();

    bool passed = verify_footprint(host_tensor.host_view(), output_extent);
    EXPECT_TRUE(passed);

    if (!passed) {
        std::ofstream output("volta_tensor_op_64x128x32_32x64x4.csv");
        output << host_tensor.host_view();
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

TEST(PredicatedTileIterator, volta_tensor_op_128x256x32_64x64x4) {
    using Layout = cutlass::layout::RowMajor;
    using Element = int;

    static int const kElementsPerAccess =
            128 / cutlass::sizeof_bits<Element>::value;
    static int const kThreads = 256;

#if 1

    using ThreadMap =
            cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
                    cutlass::epilogue::threadblock::OutputTileShape<256, 2, 4,
                                                                    2, 1>,
                    cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1,
                                                                    8>,
                    kThreads, kElementsPerAccess,
                    cutlass::sizeof_bits<Element>::value>;

#else
    using InternalThreadMap =
            cutlass::transform::PitchLinearStripminedThreadMap<
                    cutlass::layout::PitchLinearShape<256, 16>, kThreads,
                    kElementsPerAccess>;

    using Shape =
            cutlass::epilogue::threadblock::OutputTileShape<256,  // column
                                                            2,    // row
                                                            4,    // group
                                                            2,    // cluster
                                                            8     // iterations
                                                            >;

    using Iterations =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            1,  // row
                                                            2,  // group
                                                            2,  // cluster
                                                            1   // iterations
                                                            >;

    using Delta =
            cutlass::epilogue::threadblock::OutputTileShape<1,   // column
                                                            1,   // row
                                                            16,  // group
                                                            64,  // cluster
                                                            1    // iterations
                                                            >;

    using Count =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            4,  // row
                                                            2,  // group
                                                            1,  // cluster
                                                            8   // iterations
                                                            >;

    using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap<
            InternalThreadMap, Shape, Iterations, Delta, Count>;
#endif

    using PredicatedTileIterator =
            cutlass::epilogue::threadblock::PredicatedTileIterator<ThreadMap,
                                                                   Element>;

    //
    // Initialize workspace
    //
    cutlass::MatrixCoord tensor_extent{128, 256};
    cutlass::MatrixCoord output_extent{128, 256};

    //
    // Configure parameters
    //

    cutlass::HostTensor<Element, Layout> host_tensor(tensor_extent);

    typename PredicatedTileIterator::Params iterator_params(
            host_tensor.layout());

    host_tensor.sync_device();

    //
    // Launch kernel
    //

    dim3 grid(1, 1);
    dim3 block(kThreads, 1);

    test::epilogue::threadblock::kernel_store_iterator<PredicatedTileIterator>
            <<<grid, block>>>(iterator_params, host_tensor.device_ref(),
                              output_extent);

    cudaError_t result = cudaDeviceSynchronize();
    ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result);

    //
    // Verify results
    //

    host_tensor.sync_host();

    bool passed = verify_footprint(host_tensor.host_view(), output_extent);
    EXPECT_TRUE(passed);

    if (!passed || true) {
        std::ofstream output("volta_tensor_op_128x256x32_64x64x4.csv");
        output << host_tensor.host_view();
    }
}

TEST(PredicatedTileIterator, volta_tensor_op_256x128x32_64x64x4) {
    using Layout = cutlass::layout::RowMajor;
    using Element = int;

    static int const kElementsPerAccess =
            128 / cutlass::sizeof_bits<Element>::value;
    static int const kThreads = 256;

    using ThreadMap =
            cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
                    cutlass::epilogue::threadblock::OutputTileShape<128, 2, 4,
                                                                    4, 1>,
                    cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1,
                                                                    8>,
                    kThreads, kElementsPerAccess,
                    cutlass::sizeof_bits<Element>::value>;

    using PredicatedTileIterator =
            cutlass::epilogue::threadblock::PredicatedTileIterator<ThreadMap,
                                                                   Element>;

    //
    // Initialize workspace
    //
    cutlass::MatrixCoord tensor_extent{256, 128};
    cutlass::MatrixCoord output_extent{256, 128};

    //
    // Configure parameters
    //

    cutlass::HostTensor<Element, Layout> host_tensor(tensor_extent);

    typename PredicatedTileIterator::Params iterator_params(
            host_tensor.layout());

    host_tensor.sync_device();

    //
    // Launch kernel
    //

    dim3 grid(1, 1);
    dim3 block(kThreads, 1);

    test::epilogue::threadblock::kernel_store_iterator<PredicatedTileIterator>
            <<<grid, block>>>(iterator_params, host_tensor.device_ref(),
                              output_extent);

    cudaError_t result = cudaDeviceSynchronize();
    ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result);

    //
    // Verify results
    //

    host_tensor.sync_host();

    bool passed = verify_footprint(host_tensor.host_view(), output_extent);
    EXPECT_TRUE(passed);

    if (!passed || true) {
        std::ofstream output("volta_tensor_op_256x128x32_64x64x4.csv");
        output << host_tensor.host_view();
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

TEST(PredicatedTileIterator, simt_32x64x8_32x64x1) {
    using Layout = cutlass::layout::RowMajor;
    using Element = int;

    static int const kElementsPerAccess =
            32 / cutlass::sizeof_bits<Element>::value;
    static int const kThreads = 32;

#if 1

    using ThreadMap =
            cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
                    cutlass::epilogue::threadblock::OutputTileShape<64, 1, 4, 1,
                                                                    1>,
                    cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1,
                                                                    8>,
                    kThreads, kElementsPerAccess,
                    cutlass::sizeof_bits<Element>::value>;

#else
    using InternalThreadMap =
            cutlass::transform::PitchLinearStripminedThreadMap<
                    cutlass::layout::PitchLinearShape<64, 4>, kThreads,
                    kElementsPerAccess>;

    using Shape =
            cutlass::epilogue::threadblock::OutputTileShape<64,  // column
                                                            1,   // row
                                                            4,   // group
                                                            1,   // cluster
                                                            1    // iterations
                                                            >;

    using Iterations =
            cutlass::epilogue::threadblock::OutputTileShape<2,  // column
                                                            1,  // row
                                                            4,  // group
                                                            1,  // cluster
                                                            1   // iterations
                                                            >;

    using Delta =
            cutlass::epilogue::threadblock::OutputTileShape<32,  // column
                                                            1,   // row
                                                            4,   // group
                                                            16,  // cluster
                                                            1    // iterations
                                                            >;

    using Count =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            4,  // row
                                                            2,  // group
                                                            1,  // cluster
                                                            8   // iterations
                                                            >;

    using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap<
            InternalThreadMap, Shape, Iterations, Delta, Count>;
#endif

    using PredicatedTileIterator =
            cutlass::epilogue::threadblock::PredicatedTileIterator<ThreadMap,
                                                                   Element>;

    //
    // Initialize workspace
    //
    cutlass::MatrixCoord tensor_extent{32, 64};
    cutlass::MatrixCoord output_extent{27, 63};

    //
    // Configure parameters
    //

    cutlass::HostTensor<Element, Layout> host_tensor(tensor_extent);

    typename PredicatedTileIterator::Params iterator_params(
            host_tensor.layout());

    host_tensor.sync_device();

    //
    // Launch kernel
    //

    dim3 grid(1, 1);
    dim3 block(kThreads, 1);

    test::epilogue::threadblock::kernel_store_iterator<PredicatedTileIterator>
            <<<grid, block>>>(iterator_params, host_tensor.device_ref(),
                              output_extent);

    cudaError_t result = cudaDeviceSynchronize();
    ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result);

    //
    // Verify results
    //

    host_tensor.sync_host();

    bool passed = verify_footprint(host_tensor.host_view(), output_extent);
    EXPECT_TRUE(passed);

    if (!passed) {
        std::ofstream output("simt_32x64x8_32x64x1.csv");
        output << host_tensor.host_view();
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

TEST(PredicatedTileIterator, simt_128x128x8_32x64x1) {
    using Layout = cutlass::layout::RowMajor;
    using Element = int;

    static int const kElementsPerAccess =
            32 / cutlass::sizeof_bits<Element>::value;
    static int const kThreads = 256;

#if 1

    using ThreadMap =
            cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
                    cutlass::epilogue::threadblock::OutputTileShape<128, 1, 4,
                                                                    4, 1>,
                    cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1,
                                                                    8>,
                    kThreads, kElementsPerAccess,
                    cutlass::sizeof_bits<Element>::value>;

#else
    using InternalThreadMap =
            cutlass::transform::PitchLinearStripminedThreadMap<
                    cutlass::layout::PitchLinearShape<128, 16>, kThreads,
                    kElementsPerAccess>;

    using Shape =
            cutlass::epilogue::threadblock::OutputTileShape<128,  // column
                                                            1,    // row
                                                            4,    // group
                                                            4,    // cluster
                                                            1     // iterations
                                                            >;

    using Iterations =
            cutlass::epilogue::threadblock::OutputTileShape<2,  // column
                                                            1,  // row
                                                            2,  // group
                                                            4,  // cluster
                                                            1   // iterations
                                                            >;

    using Delta =
            cutlass::epilogue::threadblock::OutputTileShape<32,  // column
                                                            1,   // row
                                                            8,   // group
                                                            32,  // cluster
                                                            1    // iterations
                                                            >;

    using Count =
            cutlass::epilogue::threadblock::OutputTileShape<1,  // column
                                                            4,  // row
                                                            2,  // group
                                                            1,  // cluster
                                                            8   // iterations
                                                            >;

    using ThreadMap = cutlass::epilogue::threadblock::OutputTileThreadMap<
            InternalThreadMap, Shape, Iterations, Delta, Count>;
#endif

    using PredicatedTileIterator =
            cutlass::epilogue::threadblock::PredicatedTileIterator<ThreadMap,
                                                                   Element>;

    //
    // Initialize workspace
    //
    cutlass::MatrixCoord tensor_extent{128, 128};
    cutlass::MatrixCoord output_extent{123, 121};

    //
    // Configure parameters
    //

    cutlass::HostTensor<Element, Layout> host_tensor(tensor_extent);

    typename PredicatedTileIterator::Params iterator_params(
            host_tensor.layout());

    host_tensor.sync_device();

    //
    // Launch kernel
    //

    dim3 grid(1, 1);
    dim3 block(kThreads, 1);

    test::epilogue::threadblock::kernel_store_iterator<PredicatedTileIterator>
            <<<grid, block>>>(iterator_params, host_tensor.device_ref(),
                              output_extent);

    cudaError_t result = cudaDeviceSynchronize();
    ASSERT_EQ(result, cudaSuccess) << cudaGetErrorString(result);

    //
    // Verify results
    //

    host_tensor.sync_host();

    bool passed = verify_footprint(host_tensor.host_view(), output_extent);
    EXPECT_TRUE(passed);

    if (!passed) {
        std::ofstream output("simt_128x128x8_32x64x1.csv");
        output << host_tensor.host_view();
    }
}

///////////////////////////////////////////////////////////////////////////////////////////////////
